{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021-2022 @ Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd\n",
    "\n",
    "This code is a part of Cybertron package.\n",
    "\n",
    "The Cybertron is open-source software based on the AI-framework:\n",
    "MindSpore (https://www.mindspore.cn/)\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "\n",
    "You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "Cybertron tutorial 10: Run MD simulation with CybertronFF as potential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(27387:139671448667968,MainProcess):2022-08-10-17:53:18.841.538 [mindspore/run_check/_check_version.py:137] Can not found cuda libs, please confirm that the correct cuda version has been installed, you can refer to the installation guidelines: https://www.mindspore.cn/install\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from mindspore import load_checkpoint\n",
    "from mindspore import context\n",
    "\n",
    "from mindsponge import Molecule\n",
    "from mindsponge import Sponge\n",
    "from mindsponge import set_global_units\n",
    "from mindsponge.callback import RunInfo, WriteH5MD\n",
    "from mindsponge.control import LeapFrog\n",
    "from mindsponge.control import Langevin\n",
    "from mindsponge.optimizer import DynamicUpdater\n",
    "\n",
    "from cybertron.model import MolCT\n",
    "from cybertron.readout import AtomwiseReadout\n",
    "from cybertron.cybertron import CybertronFF\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"GPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_global_units('A', 'kcal/mol')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "atom_types = np.array(\n",
    "    [[6, 1, 1, 1, 6, 8, 6, 8, 6, 8, 6, 8, 6, 8, 25]], np.int32)\n",
    "coordinate = np.array([\n",
    "    [0.782936, -0.21384, 1.940403],\n",
    "    [0.90026, -1.258313, 2.084498],\n",
    "    [1.793443, 0.267702, 1.791434],\n",
    "    [0.161631, 0.247471, 2.702921],\n",
    "    [-1.775807, 0.660242, 0.992526],\n",
    "    [-2.573144, 0.82639, 1.806692],\n",
    "    [-0.793238, 0.551875, -1.559148],\n",
    "    [-0.922246, 0.719072, -2.702972],\n",
    "    [1.526357, -0.229486, -0.35567],\n",
    "    [2.624975, -0.473657, -0.641924],\n",
    "    [-0.786405, -1.533853, -0.007962],\n",
    "    [-1.266142, -2.537492, 0.254628],\n",
    "    [0.394547, 1.910025, 0.468161],\n",
    "    [0.747036, 3.027445, 0.458565],\n",
    "    [-0.163356, 0.241977, 0.175396],\n",
    "])\n",
    "system = Molecule(atomic_number=atom_types, coordinate=coordinate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = MolCT(\n",
    "    dim_feature=128,\n",
    "    num_atom_types=100,\n",
    "    n_interaction=3,\n",
    "    n_heads=8,\n",
    "    max_cycles=1,\n",
    "    cutoff=10,\n",
    "    fixed_cycles=True,\n",
    "    length_unit='A',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "readout = AtomwiseReadout(\n",
    "    model=mod,\n",
    "    dim_output=1,\n",
    "    activation=mod.activation,\n",
    "    scale=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "potential = CybertronFF(\n",
    "    model=mod,\n",
    "    readout=readout,\n",
    "    atom_types=atom_types,\n",
    "    length_unit='A',\n",
    "    energy_unit='kcal/mol',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model.atom_embedding.embedding_table': Parameter (name=model.atom_embedding.embedding_table, shape=(100, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.linear.weight': Parameter (name=model.dis_filter.linear.weight, shape=(128, 64), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.linear.bias': Parameter (name=model.dis_filter.linear.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.residual.nonlinear.mlp.0.weight': Parameter (name=model.dis_filter.residual.nonlinear.mlp.0.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.residual.nonlinear.mlp.0.bias': Parameter (name=model.dis_filter.residual.nonlinear.mlp.0.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.residual.nonlinear.mlp.1.weight': Parameter (name=model.dis_filter.residual.nonlinear.mlp.1.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.dis_filter.residual.nonlinear.mlp.1.bias': Parameter (name=model.dis_filter.residual.nonlinear.mlp.1.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.positional_embedding.norm.gamma': Parameter (name=model.interactions.0.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.positional_embedding.norm.beta': Parameter (name=model.interactions.0.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.positional_embedding.x2q.weight': Parameter (name=model.interactions.0.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.positional_embedding.x2k.weight': Parameter (name=model.interactions.0.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.positional_embedding.x2v.weight': Parameter (name=model.interactions.0.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.0.multi_head_attention.output.weight': Parameter (name=model.interactions.0.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.positional_embedding.norm.gamma': Parameter (name=model.interactions.1.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.positional_embedding.norm.beta': Parameter (name=model.interactions.1.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.positional_embedding.x2q.weight': Parameter (name=model.interactions.1.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.positional_embedding.x2k.weight': Parameter (name=model.interactions.1.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.positional_embedding.x2v.weight': Parameter (name=model.interactions.1.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.1.multi_head_attention.output.weight': Parameter (name=model.interactions.1.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.positional_embedding.norm.gamma': Parameter (name=model.interactions.2.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.positional_embedding.norm.beta': Parameter (name=model.interactions.2.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.positional_embedding.x2q.weight': Parameter (name=model.interactions.2.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.positional_embedding.x2k.weight': Parameter (name=model.interactions.2.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.positional_embedding.x2v.weight': Parameter (name=model.interactions.2.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'model.interactions.2.multi_head_attention.output.weight': Parameter (name=model.interactions.2.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'readout.decoder.output.mlp.0.weight': Parameter (name=readout.decoder.output.mlp.0.weight, shape=(64, 128), dtype=Float32, requires_grad=True),\n",
       " 'readout.decoder.output.mlp.0.bias': Parameter (name=readout.decoder.output.mlp.0.bias, shape=(64,), dtype=Float32, requires_grad=True),\n",
       " 'readout.decoder.output.mlp.1.weight': Parameter (name=readout.decoder.output.mlp.1.weight, shape=(1, 64), dtype=Float32, requires_grad=True),\n",
       " 'readout.decoder.output.mlp.1.bias': Parameter (name=readout.decoder.output.mlp.1.bias, shape=(1,), dtype=Float32, requires_grad=True),\n",
       " 'step': Parameter (name=step, shape=(), dtype=Int32, requires_grad=True),\n",
       " 'global_step': Parameter (name=global_step, shape=(1,), dtype=Int32, requires_grad=True),\n",
       " 'beta1_power': Parameter (name=beta1_power, shape=(1,), dtype=Float32, requires_grad=True),\n",
       " 'beta2_power': Parameter (name=beta2_power, shape=(1,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.atom_embedding.embedding_table': Parameter (name=moment1.model.atom_embedding.embedding_table, shape=(100, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.linear.weight': Parameter (name=moment1.model.dis_filter.linear.weight, shape=(128, 64), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.linear.bias': Parameter (name=moment1.model.dis_filter.linear.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.residual.nonlinear.mlp.0.weight': Parameter (name=moment1.model.dis_filter.residual.nonlinear.mlp.0.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.residual.nonlinear.mlp.0.bias': Parameter (name=moment1.model.dis_filter.residual.nonlinear.mlp.0.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.residual.nonlinear.mlp.1.weight': Parameter (name=moment1.model.dis_filter.residual.nonlinear.mlp.1.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.dis_filter.residual.nonlinear.mlp.1.bias': Parameter (name=moment1.model.dis_filter.residual.nonlinear.mlp.1.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.positional_embedding.norm.gamma': Parameter (name=moment1.model.interactions.0.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.positional_embedding.norm.beta': Parameter (name=moment1.model.interactions.0.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.positional_embedding.x2q.weight': Parameter (name=moment1.model.interactions.0.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.positional_embedding.x2k.weight': Parameter (name=moment1.model.interactions.0.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.positional_embedding.x2v.weight': Parameter (name=moment1.model.interactions.0.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.0.multi_head_attention.output.weight': Parameter (name=moment1.model.interactions.0.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.positional_embedding.norm.gamma': Parameter (name=moment1.model.interactions.1.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.positional_embedding.norm.beta': Parameter (name=moment1.model.interactions.1.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.positional_embedding.x2q.weight': Parameter (name=moment1.model.interactions.1.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.positional_embedding.x2k.weight': Parameter (name=moment1.model.interactions.1.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.positional_embedding.x2v.weight': Parameter (name=moment1.model.interactions.1.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.1.multi_head_attention.output.weight': Parameter (name=moment1.model.interactions.1.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.positional_embedding.norm.gamma': Parameter (name=moment1.model.interactions.2.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.positional_embedding.norm.beta': Parameter (name=moment1.model.interactions.2.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.positional_embedding.x2q.weight': Parameter (name=moment1.model.interactions.2.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.positional_embedding.x2k.weight': Parameter (name=moment1.model.interactions.2.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.positional_embedding.x2v.weight': Parameter (name=moment1.model.interactions.2.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.model.interactions.2.multi_head_attention.output.weight': Parameter (name=moment1.model.interactions.2.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.readout.decoder.output.mlp.0.weight': Parameter (name=moment1.readout.decoder.output.mlp.0.weight, shape=(64, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment1.readout.decoder.output.mlp.0.bias': Parameter (name=moment1.readout.decoder.output.mlp.0.bias, shape=(64,), dtype=Float32, requires_grad=True),\n",
       " 'moment1.readout.decoder.output.mlp.1.weight': Parameter (name=moment1.readout.decoder.output.mlp.1.weight, shape=(1, 64), dtype=Float32, requires_grad=True),\n",
       " 'moment1.readout.decoder.output.mlp.1.bias': Parameter (name=moment1.readout.decoder.output.mlp.1.bias, shape=(1,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.atom_embedding.embedding_table': Parameter (name=moment2.model.atom_embedding.embedding_table, shape=(100, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.linear.weight': Parameter (name=moment2.model.dis_filter.linear.weight, shape=(128, 64), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.linear.bias': Parameter (name=moment2.model.dis_filter.linear.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.residual.nonlinear.mlp.0.weight': Parameter (name=moment2.model.dis_filter.residual.nonlinear.mlp.0.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.residual.nonlinear.mlp.0.bias': Parameter (name=moment2.model.dis_filter.residual.nonlinear.mlp.0.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.residual.nonlinear.mlp.1.weight': Parameter (name=moment2.model.dis_filter.residual.nonlinear.mlp.1.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.dis_filter.residual.nonlinear.mlp.1.bias': Parameter (name=moment2.model.dis_filter.residual.nonlinear.mlp.1.bias, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.positional_embedding.norm.gamma': Parameter (name=moment2.model.interactions.0.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.positional_embedding.norm.beta': Parameter (name=moment2.model.interactions.0.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.positional_embedding.x2q.weight': Parameter (name=moment2.model.interactions.0.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.positional_embedding.x2k.weight': Parameter (name=moment2.model.interactions.0.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.positional_embedding.x2v.weight': Parameter (name=moment2.model.interactions.0.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.0.multi_head_attention.output.weight': Parameter (name=moment2.model.interactions.0.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.positional_embedding.norm.gamma': Parameter (name=moment2.model.interactions.1.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.positional_embedding.norm.beta': Parameter (name=moment2.model.interactions.1.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.positional_embedding.x2q.weight': Parameter (name=moment2.model.interactions.1.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.positional_embedding.x2k.weight': Parameter (name=moment2.model.interactions.1.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.positional_embedding.x2v.weight': Parameter (name=moment2.model.interactions.1.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.1.multi_head_attention.output.weight': Parameter (name=moment2.model.interactions.1.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.positional_embedding.norm.gamma': Parameter (name=moment2.model.interactions.2.positional_embedding.norm.gamma, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.positional_embedding.norm.beta': Parameter (name=moment2.model.interactions.2.positional_embedding.norm.beta, shape=(128,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.positional_embedding.x2q.weight': Parameter (name=moment2.model.interactions.2.positional_embedding.x2q.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.positional_embedding.x2k.weight': Parameter (name=moment2.model.interactions.2.positional_embedding.x2k.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.positional_embedding.x2v.weight': Parameter (name=moment2.model.interactions.2.positional_embedding.x2v.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.model.interactions.2.multi_head_attention.output.weight': Parameter (name=moment2.model.interactions.2.multi_head_attention.output.weight, shape=(128, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.readout.decoder.output.mlp.0.weight': Parameter (name=moment2.readout.decoder.output.mlp.0.weight, shape=(64, 128), dtype=Float32, requires_grad=True),\n",
       " 'moment2.readout.decoder.output.mlp.0.bias': Parameter (name=moment2.readout.decoder.output.mlp.0.bias, shape=(64,), dtype=Float32, requires_grad=True),\n",
       " 'moment2.readout.decoder.output.mlp.1.weight': Parameter (name=moment2.readout.decoder.output.mlp.1.weight, shape=(1, 64), dtype=Float32, requires_grad=True),\n",
       " 'moment2.readout.decoder.output.mlp.1.bias': Parameter (name=moment2.readout.decoder.output.mlp.1.bias, shape=(1,), dtype=Float32, requires_grad=True)}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_file = 'checkpoint_c10.ckpt'\n",
    "load_checkpoint(param_file, net=potential)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = DynamicUpdater(\n",
    "    system,\n",
    "    integrator=LeapFrog(system),\n",
    "    thermostat=Langevin(system, 300),\n",
    "    time_step=1e-4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[58.57325]]\n"
     ]
    }
   ],
   "source": [
    "md = Sponge(system, potential, opt)\n",
    "print(md.energy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cb_h5md = WriteH5MD(system, 'Tutorial_C10.h5md', save_freq=10)\n",
    "cb_sim = RunInfo(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step: 0, E_pot: 58.57325, E_kin: 0.0, E_tot: 58.57325, Temperature: 0.0\n",
      "Step: 10, E_pot: 58.536163, E_kin: 0.04148401, E_tot: 58.57765, Temperature: 0.9940745\n",
      "Step: 20, E_pot: 58.425972, E_kin: 0.17119402, E_tot: 58.597164, Temperature: 4.102294\n",
      "Step: 30, E_pot: 58.227722, E_kin: 0.38126412, E_tot: 58.608986, Temperature: 9.136169\n",
      "Step: 40, E_pot: 57.925835, E_kin: 0.69324285, E_tot: 58.619076, Temperature: 16.612064\n",
      "Step: 50, E_pot: 57.486073, E_kin: 1.1403638, E_tot: 58.62644, Temperature: 27.326351\n",
      "Step: 60, E_pot: 56.886585, E_kin: 1.773051, E_tot: 58.659637, Temperature: 42.48733\n",
      "Step: 70, E_pot: 56.16535, E_kin: 2.5168397, E_tot: 58.68219, Temperature: 60.310616\n",
      "Step: 80, E_pot: 55.449783, E_kin: 3.2099535, E_tot: 58.659737, Temperature: 76.91959\n",
      "Step: 90, E_pot: 54.85717, E_kin: 3.7514, E_tot: 58.60857, Temperature: 89.89419\n",
      "Step: 100, E_pot: 54.382294, E_kin: 4.1613407, E_tot: 58.543633, Temperature: 99.71752\n",
      "Step: 110, E_pot: 53.957165, E_kin: 4.5936527, E_tot: 58.55082, Temperature: 110.07694\n",
      "Step: 120, E_pot: 53.49989, E_kin: 5.0548677, E_tot: 58.554756, Temperature: 121.12897\n",
      "Step: 130, E_pot: 52.93885, E_kin: 5.5801845, E_tot: 58.519035, Temperature: 133.71706\n",
      "Step: 140, E_pot: 52.25129, E_kin: 6.108301, E_tot: 58.35959, Temperature: 146.37221\n",
      "Step: 150, E_pot: 51.47289, E_kin: 6.827892, E_tot: 58.30078, Temperature: 163.61565\n",
      "Step: 160, E_pot: 50.642105, E_kin: 7.6543865, E_tot: 58.296494, Temperature: 183.4208\n",
      "Step: 170, E_pot: 49.621086, E_kin: 8.739233, E_tot: 58.36032, Temperature: 209.41681\n",
      "Step: 180, E_pot: 48.25978, E_kin: 10.073587, E_tot: 58.333366, Temperature: 241.39172\n",
      "Step: 190, E_pot: 47.162857, E_kin: 11.145048, E_tot: 58.307907, Temperature: 267.06696\n",
      "Step: 200, E_pot: 46.160683, E_kin: 12.207607, E_tot: 58.36829, Temperature: 292.52887\n",
      "Step: 210, E_pot: 45.06707, E_kin: 13.227571, E_tot: 58.29464, Temperature: 316.9701\n",
      "Step: 220, E_pot: 44.00579, E_kin: 14.309448, E_tot: 58.31524, Temperature: 342.89496\n",
      "Step: 230, E_pot: 43.01633, E_kin: 15.133271, E_tot: 58.1496, Temperature: 362.63608\n",
      "Step: 240, E_pot: 42.090492, E_kin: 15.981769, E_tot: 58.07226, Temperature: 382.96848\n",
      "Step: 250, E_pot: 41.401764, E_kin: 16.67428, E_tot: 58.076042, Temperature: 399.56302\n",
      "Step: 260, E_pot: 41.092117, E_kin: 16.959068, E_tot: 58.051186, Temperature: 406.38736\n",
      "Step: 270, E_pot: 41.066536, E_kin: 16.94873, E_tot: 58.015266, Temperature: 406.13965\n",
      "Step: 280, E_pot: 41.16301, E_kin: 16.775509, E_tot: 57.93852, Temperature: 401.98874\n",
      "Step: 290, E_pot: 41.211586, E_kin: 16.881145, E_tot: 58.09273, Temperature: 404.5201\n",
      "Step: 300, E_pot: 41.092056, E_kin: 17.258114, E_tot: 58.35017, Temperature: 413.55334\n",
      "Step: 310, E_pot: 40.81884, E_kin: 17.575943, E_tot: 58.394783, Temperature: 421.16943\n",
      "Step: 320, E_pot: 40.36782, E_kin: 17.960974, E_tot: 58.328796, Temperature: 430.39584\n",
      "Step: 330, E_pot: 39.778458, E_kin: 18.308634, E_tot: 58.08709, Temperature: 438.72678\n",
      "Step: 340, E_pot: 39.19735, E_kin: 19.259539, E_tot: 58.456886, Temperature: 461.51315\n",
      "Step: 350, E_pot: 38.78582, E_kin: 19.462898, E_tot: 58.24872, Temperature: 466.38623\n",
      "Step: 360, E_pot: 38.41848, E_kin: 19.963326, E_tot: 58.381805, Temperature: 478.3779\n",
      "Step: 370, E_pot: 37.619656, E_kin: 20.688374, E_tot: 58.30803, Temperature: 495.7521\n",
      "Step: 380, E_pot: 36.158264, E_kin: 22.215466, E_tot: 58.37373, Temperature: 532.3456\n",
      "Step: 390, E_pot: 34.38509, E_kin: 24.02639, E_tot: 58.41148, Temperature: 575.7404\n",
      "Step: 400, E_pot: 32.73697, E_kin: 25.745998, E_tot: 58.482967, Temperature: 616.94714\n",
      "Step: 410, E_pot: 31.487047, E_kin: 27.02228, E_tot: 58.509327, Temperature: 647.53046\n",
      "Step: 420, E_pot: 30.616928, E_kin: 28.058525, E_tot: 58.675453, Temperature: 672.3618\n",
      "Step: 430, E_pot: 30.045885, E_kin: 28.636457, E_tot: 58.682343, Temperature: 686.2107\n",
      "Step: 440, E_pot: 29.64339, E_kin: 29.213367, E_tot: 58.85676, Temperature: 700.0351\n",
      "Step: 450, E_pot: 29.11487, E_kin: 29.827196, E_tot: 58.942066, Temperature: 714.7442\n",
      "Step: 460, E_pot: 28.339184, E_kin: 30.52714, E_tot: 58.866325, Temperature: 731.51685\n",
      "Step: 470, E_pot: 27.38255, E_kin: 31.40363, E_tot: 58.78618, Temperature: 752.52\n",
      "Step: 480, E_pot: 26.327927, E_kin: 32.43982, E_tot: 58.767746, Temperature: 777.3501\n",
      "Step: 490, E_pot: 25.566223, E_kin: 33.293724, E_tot: 58.859947, Temperature: 797.8121\n",
      "Step: 500, E_pot: 25.064247, E_kin: 33.85193, E_tot: 58.916176, Temperature: 811.18823\n",
      "Step: 510, E_pot: 24.823431, E_kin: 33.965824, E_tot: 58.789253, Temperature: 813.9175\n",
      "Step: 520, E_pot: 24.706427, E_kin: 34.233784, E_tot: 58.94021, Temperature: 820.3385\n",
      "Step: 530, E_pot: 24.51696, E_kin: 34.676964, E_tot: 59.193924, Temperature: 830.9584\n",
      "Step: 540, E_pot: 24.53661, E_kin: 34.47619, E_tot: 59.0128, Temperature: 826.1473\n",
      "Step: 550, E_pot: 24.790684, E_kin: 33.96901, E_tot: 58.759693, Temperature: 813.9938\n",
      "Step: 560, E_pot: 25.045074, E_kin: 33.84356, E_tot: 58.888634, Temperature: 810.9876\n",
      "Step: 570, E_pot: 25.229933, E_kin: 33.312225, E_tot: 58.54216, Temperature: 798.2554\n",
      "Step: 580, E_pot: 25.423386, E_kin: 32.950405, E_tot: 58.37379, Temperature: 789.58514\n",
      "Step: 590, E_pot: 25.833635, E_kin: 32.612152, E_tot: 58.445786, Temperature: 781.4797\n",
      "Step: 600, E_pot: 26.446373, E_kin: 31.89583, E_tot: 58.3422, Temperature: 764.3145\n",
      "Step: 610, E_pot: 27.13861, E_kin: 31.003372, E_tot: 58.141983, Temperature: 742.9287\n",
      "Step: 620, E_pot: 27.766739, E_kin: 30.23656, E_tot: 58.0033, Temperature: 724.5538\n",
      "Step: 630, E_pot: 28.139828, E_kin: 29.601048, E_tot: 57.740875, Temperature: 709.325\n",
      "Step: 640, E_pot: 28.29393, E_kin: 29.500893, E_tot: 57.794823, Temperature: 706.92505\n",
      "Step: 650, E_pot: 28.433273, E_kin: 28.869507, E_tot: 57.30278, Temperature: 691.7953\n",
      "Step: 660, E_pot: 28.525835, E_kin: 28.61742, E_tot: 57.143257, Temperature: 685.7546\n",
      "Step: 670, E_pot: 28.397285, E_kin: 28.916592, E_tot: 57.313877, Temperature: 692.9235\n",
      "Step: 680, E_pot: 27.93642, E_kin: 29.505323, E_tot: 57.441742, Temperature: 707.0312\n",
      "Step: 690, E_pot: 27.233719, E_kin: 30.37041, E_tot: 57.60413, Temperature: 727.76117\n",
      "Step: 700, E_pot: 26.602169, E_kin: 30.922264, E_tot: 57.524433, Temperature: 740.9851\n",
      "Step: 710, E_pot: 26.431536, E_kin: 31.076363, E_tot: 57.507896, Temperature: 744.6778\n",
      "Step: 720, E_pot: 26.744087, E_kin: 30.856579, E_tot: 57.600666, Temperature: 739.41113\n",
      "Step: 730, E_pot: 27.289867, E_kin: 30.351334, E_tot: 57.6412, Temperature: 727.3041\n",
      "Step: 740, E_pot: 27.823244, E_kin: 29.700943, E_tot: 57.524185, Temperature: 711.7188\n",
      "Step: 750, E_pot: 28.153341, E_kin: 29.001175, E_tot: 57.15452, Temperature: 694.95044\n",
      "Step: 760, E_pot: 28.170898, E_kin: 28.907064, E_tot: 57.077965, Temperature: 692.6952\n",
      "Step: 770, E_pot: 27.91226, E_kin: 28.925907, E_tot: 56.838165, Temperature: 693.1467\n",
      "Step: 780, E_pot: 27.433075, E_kin: 28.92317, E_tot: 56.356247, Temperature: 693.0812\n",
      "Step: 790, E_pot: 26.729599, E_kin: 29.664732, E_tot: 56.394333, Temperature: 710.8511\n",
      "Step: 800, E_pot: 25.824791, E_kin: 30.584433, E_tot: 56.409225, Temperature: 732.8898\n",
      "Step: 810, E_pot: 24.751232, E_kin: 31.543098, E_tot: 56.29433, Temperature: 755.86206\n",
      "Step: 820, E_pot: 23.693127, E_kin: 32.57229, E_tot: 56.265415, Temperature: 780.5244\n",
      "Step: 830, E_pot: 22.844482, E_kin: 33.523808, E_tot: 56.36829, Temperature: 803.3255\n",
      "Step: 840, E_pot: 22.272938, E_kin: 34.394184, E_tot: 56.66712, Temperature: 824.1822\n",
      "Step: 850, E_pot: 21.97895, E_kin: 34.612236, E_tot: 56.591187, Temperature: 829.40735\n",
      "Step: 860, E_pot: 21.911642, E_kin: 34.50482, E_tot: 56.416466, Temperature: 826.8334\n",
      "Step: 870, E_pot: 22.009918, E_kin: 34.24184, E_tot: 56.25176, Temperature: 820.5316\n",
      "Step: 880, E_pot: 22.1881, E_kin: 34.018684, E_tot: 56.206787, Temperature: 815.18414\n",
      "Step: 890, E_pot: 22.398027, E_kin: 33.796295, E_tot: 56.19432, Temperature: 809.85504\n",
      "Step: 900, E_pot: 22.535378, E_kin: 33.387913, E_tot: 55.92329, Temperature: 800.0691\n",
      "Step: 910, E_pot: 22.606684, E_kin: 33.27044, E_tot: 55.87712, Temperature: 797.254\n",
      "Step: 920, E_pot: 22.668877, E_kin: 33.477745, E_tot: 56.14662, Temperature: 802.2217\n",
      "Step: 930, E_pot: 22.787088, E_kin: 33.287266, E_tot: 56.074356, Temperature: 797.6573\n",
      "Step: 940, E_pot: 22.99094, E_kin: 32.790764, E_tot: 55.781704, Temperature: 785.75964\n",
      "Step: 950, E_pot: 23.255386, E_kin: 32.459774, E_tot: 55.71516, Temperature: 777.82825\n",
      "Step: 960, E_pot: 23.481316, E_kin: 32.100864, E_tot: 55.58218, Temperature: 769.2278\n",
      "Step: 970, E_pot: 23.65718, E_kin: 31.822397, E_tot: 55.479576, Temperature: 762.5549\n",
      "Step: 980, E_pot: 23.670387, E_kin: 31.471306, E_tot: 55.141693, Temperature: 754.1417\n",
      "Step: 990, E_pot: 23.794098, E_kin: 31.59299, E_tot: 55.38709, Temperature: 757.0577\n",
      "Run Time: 00:00:13\n"
     ]
    }
   ],
   "source": [
    "beg_time = time.time()\n",
    "md.run(1000, callbacks=[cb_sim, cb_h5md])\n",
    "end_time = time.time()\n",
    "used_time = end_time - beg_time\n",
    "m, s = divmod(used_time, 60)\n",
    "h, m = divmod(m, 60)\n",
    "print(\"Run Time: %02d:%02d:%02d\" % (h, m, s))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.5 ('mindspore-1.8')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "4976b8d1b143660084a7ba7652639898bf5b269ba26f14965a18b12288aa8002"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
